[Model] Add Olmo3 model implementation #24534
Conversation
Signed-off-by: Shane A <shanea@allenai.org>
```python
layer_idx = extract_layer_index(prefix)
sliding_window = (self.config.sliding_window
                  if self.config.layer_types[layer_idx]
                  == "sliding_attention" else None)
self.attn = Attention(
    self.num_heads,
    self.head_dim,
    self.scaling,
    num_kv_heads=self.num_kv_heads,
    cache_config=vllm_config.cache_config,
    quant_config=vllm_config.quant_config,
    per_layer_sliding_window=sliding_window,
    prefix=f"{prefix}.attn",
)

# Rotary embeddings. Rope scaling is only applied on full attention
# layers.
self.rope_scaling = (self.config.rope_scaling
                     if sliding_window is None else None)
self.rotary_emb = get_rope(
    self.head_dim,
    rotary_dim=self.head_dim,
    max_position=self.max_position_embeddings,
    base=self.rope_theta,  # type: ignore
    rope_scaling=self.rope_scaling,
)
```
Please correct me if I'm wrong, but it seems the only difference between Olmo2 and Olmo3 is the introduction of the sliding window? If so, I think we can simply modify Olmo2's attention implementation to make it fit both Olmo2 and Olmo3.
Yep that's right. I'll try to merge the Olmo3 logic into the existing Olmo2 code.
61b1b26: I've updated the Olmo2 logic to support both Olmo2 and Olmo3. The sliding window settings are new, so I've kept the Olmo3Config class.
The transformers PR is out (huggingface/transformers#40778), but the transformers folks haven't yet reviewed it or said whether they also want Olmo3 to be part of Olmo2. This vLLM implementation should hopefully be compatible with the transformers implementation regardless of which option they choose.
They're content with Olmo3 being a model implementation separate from Olmo2.
I think how Transformers organizes the modeling file for Olmo3 won't be an issue for us (they may request using modular modeling that inherits from Olmo2 to create a separate class), as long as the config doesn't change very much.
We can keep the config fork temporarily until the Transformers PR is merged, then clean this up afterwards.
Yep, that's what I had in mind! My local testing indicates that this PR works with transformers versions both with and without Olmo3.
```python
        trust_remote_code=True),
    "OlmoForCausalLM": _HfExamplesInfo("allenai/OLMo-1B-hf"),
    "Olmo2ForCausalLM": _HfExamplesInfo("allenai/OLMo-2-0425-1B"),
    "Olmo3ForCausalLM": _HfExamplesInfo("shanearora/2025-sep-a-base-model"),
```
Need to add `is_available_online=False` (if the repo isn't available yet) and/or `min_transformers_version` (if the model isn't supported by the current version).
9542485: The problem was that I put olmo3 instead of olmo2 in the registry, which I have fixed in this commit. The two solutions above don't apply, since the repo is public and this implementation is intended to work even before the transformers implementation is released.
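The two gating options the reviewer mentions can be sketched as follows. This is a hedged illustration: the real `_HfExamplesInfo` lives in vLLM's test model registry and its actual signature may differ; the field names below are taken only from the reviewer's comment, and the `repo` field name and default values are assumptions:

```python
# Hypothetical sketch of a registry entry with the gating options the
# reviewer describes: is_available_online=False skips models whose HF repo
# isn't public yet, and min_transformers_version skips models that the
# installed transformers version doesn't support.
from dataclasses import dataclass
from typing import Optional

@dataclass
class _HfExamplesInfo:
    repo: str                                   # HF repo id (field name assumed)
    is_available_online: bool = True            # False => repo not public yet
    min_transformers_version: Optional[str] = None  # None => no version gate

# The public-repo case in this PR needs neither flag, matching the reply above.
entry = _HfExamplesInfo("shanearora/2025-sep-a-base-model")
```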
Signed-off-by: Shane A <shanea@allenai.org> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Shane A <shanea@allenai.org> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn> Signed-off-by: bbartels <benjamin@bartels.dev>
Purpose
This PR adds the implementation for the upcoming Olmo 3 model. The HF implementation is being added concurrently, so the PR includes the config too.
Test Plan
The test plan is to see that basic generation (via examples/offline_inference/basic/generate.py) produces sensible output. I cannot run an HF vs vLLM comparison (in a shareable manner) because the HF implementation is being added concurrently. Nevertheless, I used a custom script to compare HF and vLLM activations and saw only minor errors (that would eventually propagate to become larger) while the generated output was identical.
Test Result
Result of running examples/offline_inference/basic/generate.py:
Excerpt of diff between HF and vLLM activations (using a custom script).
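The kind of activation comparison the test plan describes can be sketched minimally. This is a hedged toy example (the helper name, the activation values, and the tolerance are all made up; the author's actual comparison script is not shared in the PR):

```python
# Toy sketch of an HF-vs-vLLM activation diff: compare per-layer outputs
# and report the maximum absolute difference. Small early-layer numerical
# errors like these compound as they propagate through later layers.

def max_abs_diff(a, b):
    """Elementwise max |a - b| over two equal-length activation lists."""
    assert len(a) == len(b)
    return max(abs(x - y) for x, y in zip(a, b))

hf_acts = [0.1000, -0.2000, 0.3000]    # toy stand-ins for real activations
vllm_acts = [0.1001, -0.1999, 0.3002]
diff = max_abs_diff(hf_acts, vllm_acts)
# diff is on the order of 1e-4: numerically close, not bit-identical
```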
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.